{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://raw.githubusercontent.com/Qinbf/tf-model-zoo/master/README_IMG/01.jpg)\n",
    "AI MOOC： **www.ai-xlab.com**  \n",
    "如果你也是AI爱好者，可以添加我的微信一起交流：**sdxxqbf**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import operator\n",
    "\n",
    "# 已知分类的数据\n",
    "x1 = np.array([3,2,1])\n",
    "y1 = np.array([104,100,81])\n",
    "x2 = np.array([101,99,98])\n",
    "y2 = np.array([10,5,2])\n",
    "scatter1 = plt.scatter(x1,y1,c='r')\n",
    "scatter2 = plt.scatter(x2,y2,c='b')\n",
    "\n",
    "# 未知数据\n",
    "x = np.array([18])\n",
    "y = np.array([90])\n",
    "scatter3 = plt.scatter(x,y,c='k')\n",
    "\n",
    "\n",
    "#画图例\n",
    "plt.legend(handles=[scatter1,scatter2,scatter3],labels=['labelA','labelB','X'],loc='best')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 已知分类的数据\n",
    "x_data = np.array([[3,104],\n",
    "                   [2,100],\n",
    "                   [1,81],\n",
    "                   [101,10],\n",
    "                   [99,5],\n",
    "                   [81,2]])\n",
    "y_data = np.array(['A','A','A','B','B','B'])\n",
    "x_test = np.array([18,90])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算样本数量\n",
    "x_data_size = x_data.shape[0]\n",
    "x_data_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[18, 90],\n",
       "       [18, 90],\n",
       "       [18, 90],\n",
       "       [18, 90],\n",
       "       [18, 90],\n",
       "       [18, 90]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 复制x_test\n",
    "np.tile(x_test, (x_data_size,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 15, -14],\n",
       "       [ 16, -10],\n",
       "       [ 17,   9],\n",
       "       [-83,  80],\n",
       "       [-81,  85],\n",
       "       [-63,  88]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算x_test与每一个样本的差值\n",
    "diffMat = np.tile(x_test, (x_data_size,1)) - x_data\n",
    "diffMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 225,  196],\n",
       "       [ 256,  100],\n",
       "       [ 289,   81],\n",
       "       [6889, 6400],\n",
       "       [6561, 7225],\n",
       "       [3969, 7744]], dtype=int32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算差值的平方\n",
    "sqDiffMat = diffMat**2\n",
    "sqDiffMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  421,   356,   370, 13289, 13786, 11713], dtype=int32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 求和\n",
    "sqDistances = sqDiffMat.sum(axis=1)\n",
    "sqDistances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 20.51828453,  18.86796226,  19.23538406, 115.27792503,\n",
       "       117.41379817, 108.2266141 ])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 开方\n",
    "distances = sqDistances**0.5\n",
    "distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 2, 0, 5, 3, 4], dtype=int64)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 从小到大排序\n",
    "sortedDistances = distances.argsort()\n",
    "sortedDistances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "classCount = {}\n",
    "# 设置k\n",
    "k = 5\n",
    "for i in range(k):\n",
    "    # 获取标签\n",
    "    votelabel = y_data[sortedDistances[i]]\n",
    "    # 统计标签数量\n",
    "    classCount[votelabel] = classCount.get(votelabel,0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'A': 3, 'B': 2}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classCount"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('A', 3), ('B', 2)]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 根据operator.itemgetter(1)-第1个值对classCount排序，然后再取倒序\n",
    "sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True)\n",
    "sortedClassCount"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'A'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 获取数量最多的标签\n",
    "knnclass = sortedClassCount[0][0]\n",
    "knnclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
